/*
  Copyright (C) 2006 Pedro Felzenszwalb

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program; if not, write to the Free Software
  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
*/

#ifndef SEGMENT_IMAGE
#define SEGMENT_IMAGE

#include <cstdlib>
#include "image.h"
#include "misc.h"
#include "filter.h"
#include "segment-graph.h"

// Build a pseudo-random color.
//
// Each channel is taken from its own random() call, converted to uchar;
// the calls are made in r, g, b order.  The idx member is set to 0 so the
// returned struct never carries an indeterminate value when it is copied
// around; callers that need a segment id are expected to overwrite it.
rgb random_rgb(){
  rgb c;

  c.r = (uchar)random();
  c.g = (uchar)random();
  c.b = (uchar)random();
  c.idx = 0;  // placeholder, not a valid segment id

  return c;
}

// Color dissimilarity between two pixels.
//
// Returns the Euclidean distance between the (r, g, b) values found at
// (x1, y1) and at (x2, y2); each channel is read from its own
// single-channel image.
static inline float diff(image<float> *r, image<float> *g, image<float> *b,
                         int x1, int y1, int x2, int y2) {
  const float delta_r = imRef(r, x1, y1) - imRef(r, x2, y2);
  const float delta_g = imRef(g, x1, y1) - imRef(g, x2, y2);
  const float delta_b = imRef(b, x1, y1) - imRef(b, x2, y2);

  return sqrt(square(delta_r) + square(delta_g) + square(delta_b));
}

/**
 * Dissimilarity between two RGB-D pixels.
 *
 * Evaluates
 *   sqrt(wr*dr^2 + wg*dg^2 + wb*db^2 + wd*exp(dd^2))
 * where dr, dg, db and dd are the differences of the r, g, b and d images
 * between (x1, y1) and (x2, y2).  The depth term equals wd*exp(0) = wd when
 * both depths are identical, so it only vanishes when wd itself is zero.
 *
 * @param r R image
 * @param g G image
 * @param b B image
 * @param d D (depth) image
 * @param x1 first x coord
 * @param y1 first y coord
 * @param x2 second x coord
 * @param y2 second y coord
 * @param wr Weight for the R channel
 * @param wg Weight for the G channel
 * @param wb Weight for the B channel
 * @param wd Weight for the depth channel
 *
 * @return The dissimilarity
 */
static inline float depth_diff(image<float> *r, image<float> *g,
                               image<float> *b, image<float> *d,
                               int x1, int y1, int x2, int y2,
                               float wr,float wg, float wb, float wd) {
  const float delta_r = imRef(r, x1, y1) - imRef(r, x2, y2);
  const float delta_g = imRef(g, x1, y1) - imRef(g, x2, y2);
  const float delta_b = imRef(b, x1, y1) - imRef(b, x2, y2);
  const float delta_d = imRef(d, x1, y1) - imRef(d, x2, y2);

  // Terms are summed in r, g, b, d order.
  return sqrt(wr*square(delta_r) +
              wg*square(delta_g) +
              wb*square(delta_b) +
              wd*exp(square(delta_d)));
}

/*
 * Segment an image
 *
 * Returns a color image representing the segmentation.  All pixels of a
 * component receive the same randomly chosen color, and the idx member of
 * each pixel holds a per-component segment id.  Ids start at 1 and are
 * handed out in the order components are first met in a row-by-row scan.
 * The returned image is allocated with new and is not freed here.
 *
 * im: image to segment.
 * sigma: to smooth the image (passed to smooth()).
 * c: constant for threshold function (passed to segment_graph()).
 * min_size: minimum component size (enforced by post-processing stage).
 * num_ccs: output; number of connected components in the segmentation,
 *          taken after the small-component merge.
 */
image<rgb> *segment_image(image<rgb> *im, float sigma, float c, int min_size,
                          int *num_ccs) {
  int width = im->width();
  int height = im->height();

  image<float> *r = new image<float>(width, height);
  image<float> *g = new image<float>(width, height);
  image<float> *b = new image<float>(width, height);

  // split the color image into one float image per channel ...
  for (int y = 0; y < height; y++) {
    for (int x = 0; x < width; x++) {
      imRef(r, x, y) = imRef(im, x, y).r;
      imRef(g, x, y) = imRef(im, x, y).g;
      imRef(b, x, y) = imRef(im, x, y).b;
    }
  }
  // ... and smooth each channel
  image<float> *smooth_r = smooth(r, sigma);
  image<float> *smooth_g = smooth(g, sigma);
  image<float> *smooth_b = smooth(b, sigma);
  delete r;
  delete g;
  delete b;

  // build graph: each pixel is linked to its right, lower, lower-right and
  // upper-right neighbours when they exist, hence at most 4 edges per pixel
  edge *edges = new edge[width*height*4];
  int num = 0;
  for (int y = 0; y < height; y++) {
    for (int x = 0; x < width; x++) {
      if (x < width-1) {
        edges[num].a = y * width + x;
        edges[num].b = y * width + (x+1);
        edges[num].w = diff(smooth_r, smooth_g, smooth_b, x, y, x+1, y);
        num++;
      }

      if (y < height-1) {
        edges[num].a = y * width + x;
        edges[num].b = (y+1) * width + x;
        edges[num].w = diff(smooth_r, smooth_g, smooth_b, x, y, x, y+1);
        num++;
      }

      if ((x < width-1) && (y < height-1)) {
        edges[num].a = y * width + x;
        edges[num].b = (y+1) * width + (x+1);
        edges[num].w = diff(smooth_r, smooth_g, smooth_b, x, y, x+1, y+1);
        num++;
      }

      if ((x < width-1) && (y > 0)) {
        edges[num].a = y * width + x;
        edges[num].b = (y-1) * width + (x+1);
        edges[num].w = diff(smooth_r, smooth_g, smooth_b, x, y, x+1, y-1);
        num++;
      }
    }
  }
  delete smooth_r;
  delete smooth_g;
  delete smooth_b;

  // segment
  universe *u = segment_graph(width*height, num, edges, c);

  // post process small components: join the two ends of an edge whenever
  // they lie in different components and either one is below min_size
  for (int i = 0; i < num; i++) {
    int a = u->find(edges[i].a);
    int b = u->find(edges[i].b);
    if ((a != b) && ((u->size(a) < min_size) || (u->size(b) < min_size)))
      u->join(a, b);
  }
  delete [] edges;
  *num_ccs = u->num_sets();

  image<rgb> *output = new image<rgb>(width, height);

  // pick random colors for each component
  rgb *colors = new rgb[width*height];

  // Added by MWF
  // segmentids[comp] is the id given to component comp; 0 = none yet
  int *segmentids = new int[width*height];
  int segmentcounter = 1;
  // end MWF addition

  for (int i = 0; i < width*height; i++) {
    colors[i] = random_rgb();
    segmentids[i] = 0;
  }

  for (int y = 0; y < height; y++) {
    for (int x = 0; x < width; x++) {
      int comp = u->find(y * width + x);
      // copies the whole rgb value; idx is overwritten just below
      imRef(output, x, y) = colors[comp];
      // Added by MWF, 9/22/08
      // If we encounter a new comp / segment, record
      // its segmentid and increment the segmentid counter
      if (segmentids[comp] == 0)
      {
        segmentids[comp] = segmentcounter;
        segmentcounter++;
      }
      imRef(output, x, y).idx = segmentids[comp];
      // end MWF addition
    }
  }

  delete [] colors;
  delete [] segmentids;  // was never released before: leaked on every call
  delete u;

  return output;
}

/*
 * Segment an RGB-D image
 *
 * Returns a color image representing the segmentation.  All pixels of a
 * component receive the same randomly chosen color, and the idx member of
 * each pixel holds a per-component segment id.  Ids start at 1 and are
 * handed out in the order components are first met in a row-by-row scan.
 * The returned image is allocated with new and is not freed here.
 *
 * color_im: color image to segment; its width and height define the graph.
 * depth_im: depth image; read at every (x, y) of color_im, so it must be at
 *           least as large as color_im (this is not checked).
 * sigma: to smooth the image (passed to smooth() for all four channels).
 * c: constant for threshold function (passed to segment_graph()).
 * min_size: minimum component size (enforced by post-processing stage).
 * num_ccs: output; number of connected components in the segmentation,
 *          taken after the small-component merge.
 * wr: weight on the R channel
 * wg: weight on the G channel
 * wb: weight on the B channel
 * wd: weight on the depth channel
 */
image<rgb> *segment_image(image<rgb> *color_im, image<float> *depth_im,
                          float sigma, float c, int min_size, int *num_ccs,
                          float wr, float wg, float wb, float wd) {
  int width = color_im->width();
  int height = color_im->height();

  image<float> *r = new image<float>(width, height);
  image<float> *g = new image<float>(width, height);
  image<float> *b = new image<float>(width, height);
  image<float> *d = new image<float>(width, height);

  // split the input into one float image per channel (depth included) ...
  for (int y = 0; y < height; y++) {
    for (int x = 0; x < width; x++) {
      imRef(r, x, y) = imRef(color_im, x, y).r;
      imRef(g, x, y) = imRef(color_im, x, y).g;
      imRef(b, x, y) = imRef(color_im, x, y).b;
      imRef(d, x, y) = imRef(depth_im, x, y);
    }
  }
  // ... and smooth each channel
  image<float> *smooth_r = smooth(r, sigma);
  image<float> *smooth_g = smooth(g, sigma);
  image<float> *smooth_b = smooth(b, sigma);
  image<float> *smooth_d = smooth(d, sigma);
  delete r;
  delete g;
  delete b;
  delete d;

  // build graph: each pixel is linked to its right, lower, lower-right and
  // upper-right neighbours when they exist, hence at most 4 edges per pixel
  edge *edges = new edge[width*height*4];
  int num = 0;
  for (int y = 0; y < height; y++) {
    for (int x = 0; x < width; x++) {
      if (x < width-1) {
        edges[num].a = y * width + x;
        edges[num].b = y * width + (x+1);
        edges[num].w = depth_diff(smooth_r, smooth_g, smooth_b, smooth_d, x, y,
                                  x+1, y, wr, wg, wb, wd);
        num++;
      }

      if (y < height-1) {
        edges[num].a = y * width + x;
        edges[num].b = (y+1) * width + x;
        edges[num].w = depth_diff(smooth_r, smooth_g, smooth_b, smooth_d, x, y,
                                  x, y+1, wr, wg, wb, wd);
        num++;
      }

      if ((x < width-1) && (y < height-1)) {
        edges[num].a = y * width + x;
        edges[num].b = (y+1) * width + (x+1);
        edges[num].w = depth_diff(smooth_r, smooth_g, smooth_b, smooth_d, x, y,
                                  x+1, y+1, wr, wg, wb, wd);
        num++;
      }

      if ((x < width-1) && (y > 0)) {
        edges[num].a = y * width + x;
        edges[num].b = (y-1) * width + (x+1);
        edges[num].w = depth_diff(smooth_r, smooth_g, smooth_b, smooth_d, x, y,
                                  x+1, y-1, wr, wg, wb, wd);
        num++;
      }
    }
  }
  delete smooth_r;
  delete smooth_g;
  delete smooth_b;
  delete smooth_d;

  // segment
  universe *u = segment_graph(width*height, num, edges, c);

  // post process small components: join the two ends of an edge whenever
  // they lie in different components and either one is below min_size
  for (int i = 0; i < num; i++) {
    int a = u->find(edges[i].a);
    int b = u->find(edges[i].b);
    if ((a != b) && ((u->size(a) < min_size) || (u->size(b) < min_size)))
      u->join(a, b);
  }
  delete [] edges;
  *num_ccs = u->num_sets();

  image<rgb> *output = new image<rgb>(width, height);

  // pick random colors for each component
  rgb *colors = new rgb[width*height];

  // Added by MWF
  // segmentids[comp] is the id given to component comp; 0 = none yet
  int *segmentids = new int[width*height];
  int segmentcounter = 1;
  // end MWF addition

  for (int i = 0; i < width*height; i++) {
    colors[i] = random_rgb();
    segmentids[i] = 0;
  }

  for (int y = 0; y < height; y++) {
    for (int x = 0; x < width; x++) {
      int comp = u->find(y * width + x);
      // copies the whole rgb value; idx is overwritten just below
      imRef(output, x, y) = colors[comp];
      // Added by MWF, 9/22/08
      // If we encounter a new comp / segment, record
      // its segmentid and increment the segmentid counter
      if (segmentids[comp] == 0)
      {
        segmentids[comp] = segmentcounter;
        segmentcounter++;
      }
      imRef(output, x, y).idx = segmentids[comp];
      // end MWF addition
    }
  }

  delete [] colors;
  delete [] segmentids;  // was never released before: leaked on every call
  delete u;

  return output;
}

#endif
